{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "词汇表大小: 28785\n",
      "词汇示例(word to index): {'<pad>': 0, '<sos>': 1, '<eos>': 2, 'the': 3, 'apple': 11505}\n"
     ]
    }
   ],
   "source": [
    "from torchtext.datasets import WikiText2 # 导入WikiText2\n",
    "from torchtext.data.utils import get_tokenizer # 导入Tokenizer分词工具\n",
    "from torchtext.vocab import build_vocab_from_iterator # 导入Vocabulary工具\n",
    "tokenizer = get_tokenizer(\"basic_english\") # 定义数据预处理所需的tokenizer\n",
    "train_iter = WikiText2(split='train') # 加载WikiText2数据集的训练部分\n",
    "# 定义一个生成器函数，用于将数据集中的文本转换为tokens\n",
    "def yield_tokens(data_iter):\n",
    "    for item in data_iter:\n",
    "        yield tokenizer(item)\n",
    "# 创建词汇表，包括特殊tokens：\"<pad>\", \"<sos>\", \"<eos>\"\n",
    "vocab = build_vocab_from_iterator(yield_tokens(train_iter), \n",
    "                                  specials=[\"<pad>\", \"<sos>\", \"<eos>\"])\n",
    "vocab.set_default_index(vocab[\"<pad>\"])\n",
    "\n",
    "# 打印词汇表信息\n",
    "print(\"词汇表大小:\", len(vocab))\n",
    "print(\"词汇示例(word to index):\", \n",
    "      {word: vocab[word] for word in [\"<pad>\", \"<sos>\", \"<eos>\", \"the\", \"apple\"]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample 1:\n",
      "Input Data:  tensor([   1, 9209,    4,  419,   37,  181,  860,    2])\n",
      "Target Data:  tensor([    1,    67,  1734,  1633,   124,     4, 13818,   181,     5,   419,\n",
      "           76,   181,   860,     2])\n",
      "--------------------------------------------------\n",
      "Sample 2:\n",
      "Input Data:  tensor([   1,   67, 1734,  426,    4, 6733,   20, 4168,    5,  188,  115,  181,\n",
      "         289,  860,    2])\n",
      "Target Data:  tensor([   1,   67, 1734,   33, 1976,  820, 1703,    5,   67,  115,  639,  181,\n",
      "        6108, 4280,    5,    2])\n",
      "--------------------------------------------------\n",
      "Sample 3:\n",
      "Input Data:  tensor([   1,  188,   26,    3, 1508,  142,  805,  860,    2])\n",
      "Target Data:  tensor([   1, 8943, 6421,   11, 1508, 1792,   50, 3627,   20,    3, 1092, 1406,\n",
      "           5,    2])\n",
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "import torch #导入torch\n",
    "from torch.utils.data import Dataset #导入Dataset\n",
    "\n",
    "class ChatDataset(Dataset):\n",
    "    def __init__(self, file_path, tokenizer, vocab):\n",
    "        self.tokenizer = tokenizer #分词器\n",
    "        self.vocab = vocab #词汇表\n",
    "        self.input_data, self.target_data = self.load_and_process_data(file_path)\n",
    "    def load_and_process_data(self, file_path):        \n",
    "        with open(file_path, \"r\") as f:\n",
    "            lines = f.readlines() # 打开文件，读取每一行数据\n",
    "        input_data, target_data = [], []\n",
    "        for i, line in enumerate(lines):\n",
    "            if line.startswith(\"User:\"): # 移除 \"User: \" 前缀，构建输入序列\n",
    "                tokens = self.tokenizer(line.strip()[6:])  \n",
    "                tokens = [\"<sos>\"] + tokens + [\"<eos>\"]\n",
    "                indices = [self.vocab[token] for token in tokens]\n",
    "                input_data.append(torch.tensor(indices, dtype=torch.long))\n",
    "            elif line.startswith(\"AI:\"): # 移除 \"AI: \" 前缀，构建目标序列\n",
    "                tokens = self.tokenizer(line.strip()[4:])  \n",
    "                tokens = [\"<sos>\"] + tokens + [\"<eos>\"]\n",
    "                indices = [self.vocab[token] for token in tokens]\n",
    "                target_data.append(torch.tensor(indices, dtype=torch.long))\n",
    "        return input_data, target_data\n",
    "    def __len__(self): #数据集长度\n",
    "        return len(self.input_data) \n",
    "    def __getitem__(self, idx): #根据索引获取数据样本\n",
    "        return self.input_data[idx], self.target_data[idx] \n",
    "\n",
    "file_path = \"chat.txt\" # 加载chat.txt数据集\n",
    "chat_dataset = ChatDataset(file_path, tokenizer, vocab)\n",
    "\n",
    "for i in range(3): # 打印几个样本数据\n",
    "    input_sample, target_sample = chat_dataset[i]\n",
    "    print(f\"Sample {i + 1}:\")\n",
    "    print(\"Input Data: \", input_sample)\n",
    "    print(\"Target Data: \", target_sample)\n",
    "    print(\"-\" * 50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input batch tensor size: torch.Size([2, 17])\n",
      "Target batch tensor size: torch.Size([2, 17])\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import DataLoader # 导入Dataloader\n",
    "# 定义pad_sequence函数，用于将一批序列补齐到相同长度\n",
    "def pad_sequence(sequences, padding_value=0, length=None):\n",
    "    # 计算最大序列长度，如果length参数未提供，则使用输入序列中的最大长度\n",
    "    max_length = max(len(seq) for seq in sequences) if length is None else length    \n",
    "    # 创建一个具有适当形状的全零张量，用于存储补齐后的序列\n",
    "    result = torch.full((len(sequences), max_length), padding_value, dtype=torch.long)    \n",
    "    # 遍历序列，将每个序列的内容复制到结果张量中\n",
    "    for i, seq in enumerate(sequences):\n",
    "        end = len(seq)\n",
    "        result[i, :end] = seq[:end]\n",
    "    return result\n",
    "\n",
    "# 定义collate_fn函数，用于将一个批次的数据整理成适当的形状\n",
    "def collate_fn(batch):\n",
    "    # 从批次中分离源序列和目标序列\n",
    "    sources, targets = zip(*batch)    \n",
    "    # 计算批次中的最大序列长度\n",
    "    max_length = max(max(len(s) for s in sources), max(len(t) for t in targets))    \n",
    "    # 使用pad_sequence函数补齐源序列和目标序列\n",
    "    sources = pad_sequence(sources, padding_value=vocab[\"<pad>\"], length=max_length)\n",
    "    targets = pad_sequence(targets, padding_value=vocab[\"<pad>\"], length=max_length)    \n",
    "    # 返回补齐后的源序列和目标序列\n",
    "    return sources, targets\n",
    "\n",
    "# 创建Dataloader\n",
    "batch_size = 2\n",
    "chat_dataloader = DataLoader(chat_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)\n",
    "\n",
    "# 检查Dataloader输出\n",
    "for input_batch, target_batch in chat_dataloader:\n",
    "    print(\"Input batch tensor size:\", input_batch.size())\n",
    "    print(\"Target batch tensor size:\", target_batch.size())\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from GPT_Model import GPT #导入GPT模型的类（这是我们自己制作的）\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model = GPT(len(vocab), max_seq_len=256, n_layers=6).to(device) #创建模型示例\n",
    "model.load_state_dict(torch.load('trained_model_2023-05-05_14-08-24.pt')) #加载模型\n",
    "# model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/huangjia/Documents/02_NLP/70 GeekTimeNLP/10 ChatGPT/GPT_2.py:15: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525551200/work/aten/src/ATen/native/cuda/Indexing.cu:1435.)\n",
      "  scores.masked_fill_(attn_mask, -1e9)\n",
      "/home/huangjia/anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py:197: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525551200/work/aten/src/ATen/native/cuda/Indexing.cu:1435.)\n",
      "  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0020, cost = 1.975874\n",
      "Epoch: 0040, cost = 0.021781\n",
      "Epoch: 0060, cost = 0.619990\n",
      "Epoch: 0080, cost = 0.777577\n",
      "Epoch: 0100, cost = 0.004273\n"
     ]
    }
   ],
   "source": [
    "import torch.nn as nn #导入nn\n",
    "import torch.optim as optim #导入优化器\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=vocab[\"<pad>\"]) #损失函数\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.0001) # 优化器\n",
    "for epoch in range(100): # 开始训练\n",
    "    for batch_idx, (input_batch, target_batch) in enumerate(chat_dataloader):       \n",
    "        optimizer.zero_grad()  # 梯度清零        \n",
    "        input_batch, target_batch = input_batch.to(device), target_batch.to(device) #移动到设备               \n",
    "        outputs = model(input_batch) # 前向传播，计算模型输出       \n",
    "        loss = criterion(outputs.view(-1, len(vocab)), target_batch.view(-1)) # 计算损失           \n",
    "        loss.backward() # 反向传播        \n",
    "        optimizer.step() # 更新参数    \n",
    "    if (epoch + 1) % 20 == 0: # 每200个epoch打印一次损失值\n",
    "        print(f\"Epoch: {epoch + 1:04d}, cost = {loss:.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_text_beam_search(model, input_str, max_len=50, beam_width=5):\n",
    "    model.eval()  # 将模型设置为评估（测试）模式，关闭dropout和batch normalization等训练相关的层\n",
    "    # 将输入字符串中的每个token 转换为其在词汇表中的索引\n",
    "    input_tokens = [vocab[token] for token in input_str]\n",
    "    # 创建一个列表，用于存储候选序列\n",
    "    candidates = [(input_tokens, 0.0)]\n",
    "    with torch.no_grad():  # 禁用梯度计算，以节省内存并加速测试过程\n",
    "        for _ in range(max_len):  # 生成最多max_len个tokens\n",
    "            new_candidates = []\n",
    "            for candidate, candidate_score in candidates:\n",
    "                inputs = torch.LongTensor(candidate).unsqueeze(0).to(device)\n",
    "                outputs = model(inputs) # 输出 logits形状为[1, len(output_tokens), vocab_size]\n",
    "                logits = outputs[:, -1, :] # 只关心最后一个时间步（即最新生成的token）的logits\n",
    "                # 将<pad>标记的得分设置为一个很大的负数，以避免选择它\n",
    "                logits[0, vocab[\"<pad>\"]] = -1e9 # 不是这个原因，注意不认识的词汇都变成0\n",
    "                # 找到具有最高分数的前beam_width个tokens\n",
    "                scores, next_tokens = torch.topk(logits, beam_width, dim=-1)\n",
    "                final_results = []\n",
    "                for score, next_token in zip(scores.squeeze(), next_tokens.squeeze()):\n",
    "                    new_candidate = candidate + [next_token.item()]\n",
    "                    new_score = candidate_score - score.item()  # 使用负数，因为我们需要降序排列\n",
    "                    if next_token.item() == vocab[\"<eos>\"]:\n",
    "                        # 如果生成的token是EOS（结束符），将其添加到最终结果中\n",
    "                        final_results.append((new_candidate, new_score))\n",
    "                    else:\n",
    "                        # 将新生成的候选序列添加到新候选列表中\n",
    "                        new_candidates.append((new_candidate, new_score))\n",
    "            # 从新候选列表中选择得分最高的beam_width个序列\n",
    "            candidates = sorted(new_candidates, key=lambda x: x[1])[:beam_width]\n",
    "    # 选择得分最高的候选序列\n",
    "    best_candidate, _ = sorted(candidates, key=lambda x: x[1])[0]\n",
    "    # 将输出 token 转换回文本字符串\n",
    "    output_str = \" \".join([vocab.get_itos()[token] for token in best_candidate])\n",
    "    return output_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated text: hi , how are you ? thank you , depicting a suitable cavity ? prefer unknowingly a distinct indicated that link vocalists - very common starlings prefer unknowingly a distinct indicated that link vocalists - very common starlings prefer unknowingly a distinct indicated that link vocalists - very common starlings prefer unknowingly a distinct indicated horsepower\n"
     ]
    }
   ],
   "source": [
    "input_str = \"what is the weather like today ?\"\n",
    "input_str = \"hi , how are you ?\"\n",
    "# input_str = \"hi , what is you name ?\"\n",
    "\n",
    "generated_text = generate_text_beam_search(model, input_str.split())\n",
    "print(\"Generated text:\", generated_text)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
